### Specify which sample we are using ###

# CHOOSE sample size: "full", "10_percent_sample" or "1_percent_sample"
sample_choice <- "10_percent_sample"

# CHOOSE balance requirement of the panel
# (choice can be: "unbalanced", "pre_post_1_t" or "pre_post_3_t")
suffix_mild_balanced <- "pre_post_3_t"

# Map the sample choice to the clean payroll-provider file it corresponds to.
# An unrecognised choice stops here, instead of leaving `filename` undefined
# (or silently reusing a stale value from an earlier session) at readRDS().
filename <- switch(
  sample_choice,
  "full" = "data/derived/clean_payroll_provider_firm_cz.rds",
  "10_percent_sample" = "data/derived/clean_payroll_provider_firm_cz_10_percent_sample.rds",
  "1_percent_sample" = "data/derived/clean_payroll_provider_firm_cz_1_percent_sample.rds",
  stop("Unknown sample_choice: '", sample_choice,
       "'. Use 'full', '10_percent_sample' or '1_percent_sample'.", call. = FALSE)
)

#----------------------------#
#### --- LOAD DATASET --- ####
#----------------------------#
# Steps:
# 1. Load the stacked dataset and collect its unique firm identifiers (the
#    firms used throughout the regression analysis).
# 2. Keep only those firms in the "raw" clean payroll_provider data.
# 3. Winsorize the avg_wage_exact variable and restrict to years 2013-2023.

# Load stacked dataset
stacked_data <- readRDS(paste0("data/derived/mild_balanced_", suffix_mild_balanced, "_stacked_nonpolicy_firm_payroll_provider_dataset_", sample_choice, ".rds"))

# Load clean_payroll_provider data
clean_data <- readRDS(filename)

# Unique firm identifiers present in the stacked dataset
firms_list <- unique(stacked_data$clt_client)

# Keep only the above firms in the clean data
df <- clean_data %>%
  filter(clt_client %in% firms_list) %>%
  # Winsorize avg_wage_exact at P1 and P99 (probs 0.01 / 0.99), matching the
  # `_wins_1_99` name. The quantiles are taken over all of the selected firms'
  # rows, i.e. before the 2013-2023 year filter below.
  mutate(avg_wage_exact_wins_1_99 = DescTools::Winsorize(avg_wage_exact, probs = c(0.01, 0.99)),
         # `month` is assumed to start with a 4-digit year (e.g. "2015-03") - TODO confirm
         year = as.numeric(substr(month, 1, 4))) %>%
  filter(year >= 2013 & year <= 2023)


#-------------------------------#
#### --- BUILD FUNCTIONS --- ####
#-------------------------------#
## Plot the evolution of the employment-weighted average (winsorized) wage.
##
## Args:
##   df: panel with at least `year`, `tot_emp`, `avg_wage_exact_wins_1_99`
##       and the columns named in `agg_vars` / `plot_group_id`.
##   agg_vars: character vector of columns to aggregate by; must include
##       "year". With a single element, one uncoloured line is drawn.
##   plot_group_id: column identifying groups (line colour, group filter).
##   groups_to_plot: values of `plot_group_id` to keep; NULL/empty keeps all.
##   save_plot: if TRUE, write the plot to `plot_filename` as a 12x7 PDF.
##   title, subtitle, legend_title: plot text.
##   plot_filename: output path used when `save_plot` is TRUE.
##   hide_legend: if TRUE, drop the legend.
##   last_point_labels: if TRUE, label each line at its last year with its
##       group id (requires `plot_group_id` to be one of `agg_vars`, since
##       only `agg_vars` survive the aggregation).
##   color_values: colours for scale_color_manual(). Defaults to the `colors`
##       object in the environment where this function is defined
##       (presumably a palette vector created elsewhere - TODO confirm).
## Returns: the ggplot object, invisibly (it is also printed).
plot_avg_wage_evolution <- function(df, agg_vars, plot_group_id,
                                    groups_to_plot = c(), save_plot = TRUE,
                                    title = "Evolution of Average Wage", subtitle = NULL,
                                    legend_title = "Industry codes", plot_filename = "avg_wage.pdf",
                                    hide_legend = FALSE, last_point_labels = FALSE,
                                    color_values = colors) {

  # Factor so ggplot treats the group as discrete (one colour per group)
  df[[plot_group_id]] <- as.factor(df[[plot_group_id]])

  # Keep only the requested groups; %in% compares the factor via its labels
  if (length(groups_to_plot) > 0) {
    df <- df %>%
      filter(!!sym(plot_group_id) %in% groups_to_plot)
  }

  # Employment-weighted mean wage per cell: each row is weighted by its share
  # of the cell's total employment
  df <- df %>%
    group_by(across(all_of(agg_vars))) %>%
    mutate(tot_emp_group = sum(tot_emp),
           weight_obs = tot_emp / tot_emp_group) %>%
    summarize(avg_wage_group = sum(weight_obs * avg_wage_exact_wins_1_99)) %>%
    ungroup()

  # One x-axis break per year in the data range
  year_range <- range(df$year)
  year_seq <- seq(year_range[1], year_range[2], by = 1)

  # With a single aggregation variable there is no group dimension to colour
  if (length(agg_vars) == 1) {
    p <- ggplot(df, aes(x = year, y = avg_wage_group))
  } else {
    p <- ggplot(df, aes(x = year, y = avg_wage_group,
                        color = .data[[plot_group_id]],
                        group = .data[[plot_group_id]]))
  }

  # NOTE: ggplot2 >= 3.4 prefers `linewidth` for lines; `size` is kept for
  # compatibility with older ggplot2 versions.
  p <- p +
    geom_line(size = 1.2) +
    geom_point(size = 2) +
    geom_text(aes(label = round(avg_wage_group, 1)), size = 3, vjust = -0.5) +
    scale_color_manual(values = color_values, name = legend_title) +
    scale_x_continuous(breaks = year_seq, labels = as.integer(year_seq)) +
    labs(title = title,
         subtitle = subtitle,
         x = "Year",
         y = "Average Wage") +
    theme_minimal(base_size = 15)

  # Group labels at the end of each line
  if (last_point_labels) {
    last_data_points <- df %>%
      group_by(!!sym(plot_group_id)) %>%
      filter(year == max(year)) %>%
      ungroup()
    p <- p + geom_text(data = last_data_points,
                       aes(label = .data[[plot_group_id]]),
                       hjust = -0.3, size = 5)
  }

  if (hide_legend) {
    p <- p + theme(legend.position = "none")
  }

  if (save_plot) {
    ggsave(plot_filename, plot = p, device = "pdf", width = 12, height = 7)
  }
  print(p)
  invisible(p)
}


## Bar chart of the share of employment in each wage bin for one year.
##
## Args:
##   data: long data with columns `wage_bin`, `year` and `employment`
##         (e.g. one row per firm-month-bin).
##   year_to_plot: the year whose distribution is drawn.
##   plot_filename: path of the 12x7 PDF the plot is saved to.
##   wage_bin_levels: display order of the `wage_bin` labels on the x axis;
##         bins not listed here become NA.
## Returns: the ggplot object, invisibly (it is also printed).
plot_employment_by_wage_bin <- function(data, year_to_plot, plot_filename,
                                        wage_bin_levels = c("lt8", "8", "9", "10", "11", "12", "13", "14",
                                                            "15", "16", "17", "18", "19", "20", "21", "22",
                                                            "23", "24", "25", "26", "27", "28", "29", "gt30")) {

  # Employment per bin in the chosen year, then each bin's share of that
  # year's total. Shares are computed within a year, so filtering first gives
  # the same result with less work. `na.rm` must be inside sum(): passed to
  # summarise() it only created a constant `na.rm` column.
  data_filtered <- data %>%
    filter(year == year_to_plot) %>%
    group_by(wage_bin, year) %>%
    summarise(employment_bin_year = sum(employment, na.rm = TRUE), .groups = "drop") %>%
    group_by(year) %>%
    mutate(tot_emp_year = sum(employment_bin_year),
           share_emp_year = employment_bin_year / tot_emp_year) %>%
    ungroup() %>%
    # Factor with explicit levels so bins are drawn in wage order, not alphabetically
    mutate(wage_bin = factor(wage_bin, levels = wage_bin_levels))

  p <- ggplot(data_filtered, aes(x = wage_bin, y = share_emp_year)) +
    geom_bar(stat = "identity", fill = "skyblue") +
    labs(title = paste("Share of Employment by Wage Bin in", year_to_plot),
         x = "Wage Bin",
         y = "Share of Employment") +
    theme_minimal(base_size = 15) +
    # Rotate x-axis labels for readability
    theme(axis.text.x = element_text(angle = 45, hjust = 1))

  ggsave(plot_filename, plot = p, device = "pdf", width = 12, height = 7)

  print(p)
  invisible(p)
}


#-------------------------------#
#### --- GENERATE OUTPUT --- ####
#-------------------------------#
## Evolution of average wage plots ##

# Industries (3-digit codes) pooled into the single "multiple industries" line.
# Defined once so the subtitle always lists exactly the groups that are plotted.
pooled_industries <- c(561, 452, 722, 445, 448, 492, 454, 444, 451)

# Plot: average wage by industry, one line for each of three industries
plot_avg_wage_evolution(df,
                        agg_vars = c("year", "clt_industry_code_3d"),
                        plot_group_id = "clt_industry_code_3d",
                        groups_to_plot = c(452, 722, 454),
                        plot_filename = paste0("figures_tables/avg_wage_3_industries_", suffix_mild_balanced, "_", sample_choice, ".pdf"))

# Single line pooling the industries in `pooled_industries`
plot_avg_wage_evolution(df,
                        agg_vars = c("year"),
                        plot_group_id = "clt_industry_code_3d",
                        subtitle = paste0("Industries included are: ", paste(pooled_industries, collapse = ", ")),
                        hide_legend = TRUE,
                        groups_to_plot = pooled_industries,
                        plot_filename = paste0("figures_tables/avg_wage_multiple_industries_", suffix_mild_balanced, "_", sample_choice, ".pdf"))

# Single line pooling all industry codes
plot_avg_wage_evolution(df,
                        agg_vars = c("year"),
                        plot_group_id = "clt_industry_code_3d",
                        subtitle = "(All industries)",
                        hide_legend = TRUE,
                        plot_filename = paste0("figures_tables/avg_wage_all_industries_", suffix_mild_balanced, "_", sample_choice, ".pdf"))


## Distribution of employment across wage bins

# Reshape to one row per firm-month-bin; the "wage_" prefix is stripped from
# the bin column names so `wage_bin` holds labels like "lt8", "10", "gt30".
df_long <- df %>%
  select(year, month, czone, clt_client, clt_state, clt_industry_code_3d, starts_with("wage_")) %>%
  pivot_longer(cols = starts_with("wage_"),
               names_to = "wage_bin",
               names_prefix = "wage_",
               values_to = "employment")

# Same distribution figure for each snapshot year
invisible(lapply(c(2013, 2018, 2023), function(snapshot_year) {
  plot_employment_by_wage_bin(
    df_long,
    year_to_plot = snapshot_year,
    plot_filename = paste0("figures_tables/employment_distribution_", snapshot_year, "_",
                           suffix_mild_balanced, "_", sample_choice, ".pdf")
  )
}))


#EDIT: Add Employment share by industry figure
#------------------------------------#
#### Employment share by industry ####
#------------------------------------#
# Lookup of NAICS 2-digit industry codes and their names.
# NOTE(review): the join below matches on a numeric industry_code; if the sheet
# stores ranges such as "31-33" this column will be character and the join
# will fail - confirm the sheet's format.
industries <- read_xlsx("data/official_statistics/NAICS 2digit.xlsx") %>%
  select(industry_code,
         industry_name = Name)

# Employment share of each 2-digit industry in total matched employment
emp_share_industry <- df %>%
  # 2-digit industry code = first two characters of the 3-digit code
  mutate(industry_code = as.numeric(substr(clt_industry_code_3d, start = 1, stop = 2))) %>%
  left_join(industries, by = "industry_code") %>%
  # Drop rows whose industry is not in the lookup
  filter(!is.na(industry_name)) %>%
  group_by(industry_name) %>%
  # na.rm so a single missing tot_emp does not make every share NA
  summarise(emp_total_ind = sum(tot_emp, na.rm = TRUE)) %>%
  ungroup() %>%
  mutate(emp_total_all = sum(emp_total_ind),
         emp_share_ind = emp_total_ind / emp_total_all)

# Bar chart of shares, industries ordered by share
p_emp_share_industry <- ggplot(emp_share_industry,
                               aes(x = reorder(industry_name, emp_share_ind), y = 100 * emp_share_ind)) +
  geom_col(fill = colors[1]) +
  geom_text(aes(label = paste0(round(100 * emp_share_ind, 2), "%")),
            vjust = -.2) +
  # Wrap x labels longer than 25 characters onto several lines
  scale_x_discrete(labels = function(x) str_wrap(x, width = 25)) +
  theme_minimal() +
  theme(# panel.grid.major = element_blank(),
        panel.background = element_blank(),
        text = element_text(size = 13),
        axis.line = element_line(colour = "black"),
        axis.text = element_text(colour = "black"),
        axis.text.x = element_text(angle = 60, hjust = 1)) +
  labs(x = "",
       y = "Percent")

print(p_emp_share_industry)

# Save this specific plot (not whatever last_plot() happens to be).
# NOTE(review): unlike the other outputs, this filename does not include
# suffix_mild_balanced / sample_choice, so runs with different samples
# overwrite each other - confirm intended.
ggsave("figures_tables/fige3_employment_share_by_industry.pdf",
       plot = p_emp_share_industry,
       device = "pdf", width = 18, height = 7)

#EDIT: add firm and employment distribution across firm size bins
#--------------------------------------------------------------#
#### Firm and employment distribution across firm size bins ####
#--------------------------------------------------------------#
firm_size_dist <- df %>%
  # Firm size bin from each row's total firm employment
  mutate(tot_emp_class = case_when(is.na(tot_emp) ~ "NA",
                                   tot_emp < 5 ~ "1-4",
                                   tot_emp < 10 ~ "5-9",
                                   tot_emp < 20 ~ "10-19",
                                   tot_emp < 50 ~ "20-49",
                                   tot_emp < 100 ~ "50-99",
                                   tot_emp < 250 ~ "100-249",
                                   tot_emp < 500 ~ "250-499",
                                   tot_emp < 1000 ~ "500-999",
                                   TRUE ~ "1000 or more")) %>%
  # Observation counts and employment per bin.
  # NOTE(review): n() counts rows. If df has one row per firm-month (it carries
  # a `month` column), this is the share of firm-month observations rather
  # than of distinct firms - confirm intent.
  group_by(tot_emp_class) %>%
  summarise(tot_firms_by_bin = n(),
            # na.rm so rows in the "NA" bin do not turn every employment share into NA
            tot_emp_by_bin = sum(tot_emp, na.rm = TRUE)) %>%
  ungroup() %>%
  mutate(sh_firms_by_bin = tot_firms_by_bin / sum(tot_firms_by_bin),
         sh_emp_by_bin = tot_emp_by_bin / sum(tot_emp_by_bin)) %>%
  # Format shares as percentage strings for the LaTeX table
  transmute(tot_emp_class,
            `Share of firms` = paste0(round(100 * sh_firms_by_bin, 2), "%"),
            `Share of employment` = paste0(round(100 * sh_emp_by_bin, 2), "%")) %>%
  # Reshape: one row per statistic, one column per size bin
  pivot_longer(-tot_emp_class, names_to = "Variable", values_to = "value") %>%
  pivot_wider(names_from = tot_emp_class, values_from = value) %>%
  # Rows sorted by statistic name ("Share of employment", then "Share of firms")
  arrange(Variable) %>%
  # Columns in size order.
  # NOTE(review): only the four smallest bins are kept; any larger bins present
  # in the data are dropped from the table - confirm this is intended.
  select(`Variable`, `1-4`, `5-9`, `10-19`, `20-49`)

# Produce tex table
latex_table <- knitr::kable(firm_size_dist, format = "latex")
write(latex_table, file = "figures_tables/tablee4_firm_dist_analysis.tex")

